# Ryan Copus, Ryan Hübert and Paige Pellaton
# "Trading Diversity? Judicial Diversity and Case Outcomes in Federal Courts"
# American Political Science Review

# File name: chp_apsr_04_randomization.py
# Last revision date: March 24, 2024
# Questions or comments? Contact Ryan Hübert: https://ryanhubert.com/

# What does this script do?
# This script performs the randomization checks.

# Last pre-print execution of this code:
# > Date: March 24, 2024
# > Machine: MacBook Pro 14" (2021 model) with Apple M1 Max chip and 64 GB RAM
# > OS: macOS Sonoma 14.4
# > Python: version 3.10

################################################################################
# Before you run this script...
################################################################################

# You will need to install the h2o package for python, as well as Java.

# First follow these instructions to install h2o:
# https://docs.h2o.ai/h2o/latest-stable/h2o-py/docs/intro.html#installing-h2o-3

# Next, install Java: https://www.java.com/en/download/

################################################################################
# Load packages and set options
################################################################################

import pandas as pd
import os
import h2o
from h2o.estimators.random_forest import H2ORandomForestEstimator
from h2o.estimators.glm import H2OGeneralizedLinearEstimator
import time
import re
import polars, pyarrow

################################################################################
# Directory management
################################################################################

# Define working directory: the project root. If the script is launched from a
# trailing code/ (or Code/) subdirectory, that last path component is removed.
# The pattern is anchored to the end of the path so that a "/code" appearing
# elsewhere in the path (e.g. ~/code/project/code or ~/Codes/...) is kept.
wdir = re.sub(r"/[Cc]ode/?$", "", os.getcwd())

# Output folder: the lexicographically largest entry of outs/ whose name starts
# with "202" (assumes these folders are date-named, so this is the latest run).
run_dirs = [x for x in os.listdir(wdir + "/outs") if x.startswith("202")]
if not run_dirs:
    raise FileNotFoundError("No output folder starting with '202' found in " + wdir + "/outs")
cdir = wdir + "/outs/" + max(run_dirs)

# Folder for the randomization-check outputs (created if missing)
rdir = cdir + "/randomization"
os.makedirs(rdir, exist_ok=True)

# Maximum h2o memory in gigabytes; a "G" suffix is appended when the cluster is
# launched, so surrounding whitespace and a typed trailing "G"/"g" are removed.
mmem = input("How much memory do you want to use? ").strip().rstrip("Gg")

################################################################################
## Import and clean the data
################################################################################

dfile = wdir + "/data/chp_apsr_case_data.csv"

# Import case and plaintiff identity data and then merge together
# (case-data columns are all read as strings via dtype=object)
df = pd.read_csv(dfile, dtype = object)
df = df.merge(pd.read_csv(wdir + "/outs/chp_apsr_plaintiff_identities.csv").loc[:,["OPEN_ID", 'pla_female', 'pla_nonwhite_w']], on="OPEN_ID", how="left")

# Define the subsets of cases for which we will run randomization checks.
# From the model index, keep the rows with the nontraditional treatment and the
# settlement outcome, empty Plaintiff_Shares and Race_Coding, and President_Controls == 0.
tmp = pd.read_csv(cdir + "/model_index.csv")
mask = (tmp["Treatment"]=="nontraditional") & (tmp["Outcome"]=="settlement") & (tmp["Plaintiff_Shares"].isnull()) & (tmp["Race_Coding"].isnull()) & (tmp["President_Controls"]==0)
# Do we do our predictions within subsets of the dataset?
# >> For the randomization test, we look within cases heard by Democratic
#    appointees, Republican appointees and Chief Judge Preska of SDNY.
subsets = {}
for p in ["republican", "democrat"]:
    for scales in [0,1]:
        # Reuse the case sample of the first matching model, stored in
        # masks/<Model_ID zero-padded to 3 digits>.csv
        mask1 = mask & (tmp["Party"]==p) & (tmp["SCALES"]==scales)
        tmp1 = pd.read_csv(cdir + "/masks/" + str(tmp.loc[mask1, "Model_ID"].iloc[0]).zfill(3) + ".csv")
        subsets[p + ("_scales" if scales == 1 else "")] = df["OPEN_ID"].isin(tmp1["OPEN_ID"])

# "preska" subset: cases whose to_drop is empty or "L", whose OPEN_ID contains
# "nysd", and whose republican flag is 1
subsets["preska"] = ((df["to_drop"].isnull()) | (df["to_drop"] == "L"))
subsets["preska"] = subsets["preska"] & ((df["OPEN_ID"].str.contains("nysd")) & (df["republican"].isin(["1", 1])))
del tmp, tmp1

# Convert some types
df["nontraditional"] = df["nontraditional"].astype(float)
df["traditional"] = df["traditional"].astype(float)

# Remove rows from df that we won't be using: keep a row if it belongs to at
# least one subset (the subset masks share df's index)
full_mask = pd.Series(False, index=df.index)
for subset_mask in subsets.values():
    full_mask = full_mask | subset_mask
df = df.loc[full_mask,:]

# Define sets of pretreatment predictors
pvars1 = ["block", "NOS", "SECTION", "ORIGIN", "JURIS", "JURY", "PROSE", "COUNTY_NAME"]
pvars2 = ['def_count', 'oth_count', 'pla_count', 'def_attorney_count', 'oth_attorney_count',
          'pla_attorney_count', 'def_prose_count', 'oth_prose_count', 'pla_prose_count',
          'def_anonymous', 'oth_anonymous', 'pla_anonymous', 'def_repeat_party_count',
          'oth_repeat_party_count', 'pla_repeat_party_count', 'pla_female', 'pla_nonwhite_w']
tvars = ['jid', 'republican', 'democrat', 'nontraditional']

# Reduce the memory impact of the dataframe. dict.fromkeys de-duplicates while
# keeping list order; set() would yield a column order that varies between
# Python runs (string hash randomization), and that order carries into the h2o
# training frame, where it may affect the seeded learners' results.
keep_cols = list(dict.fromkeys(["OPEN_ID"] + tvars + pvars1 + pvars2))
df = df.loc[:, keep_cols]

# Convert treatment variables to integer
for v in tvars[1:]:
    df[v] = df[v].astype(int)


################################################################################
## Set up for the machine learning algorithms that will be run
################################################################################

# Define the learners
# Cross-validation settings shared by every learner: 10 folds, a fixed seed, and
# keeping the holdout predictions, per-fold models and fold assignments so that
# cross-validated metrics and predictions can be read after training.
cv_opts = dict(nfolds=10,
               keep_cross_validation_predictions=True,
               keep_cross_validation_models=True,
               keep_cross_validation_fold_assignment=True,
               seed=2020)

# NOTE(review): my_ols and my_ensemble use h2o's default fold_assignment while
# my_lasso and my_rf use "Modulo"; left as-is to preserve published results —
# confirm this is intended.
base_learners = {
    # Unpenalized (lambda_=0), unstandardized logistic regression
    "my_ols": H2OGeneralizedLinearEstimator(family="binomial",
                                            lambda_=0,
                                            standardize=False,
                                            **cv_opts),
    # Logistic LASSO (alpha=1) over a path of 100 lambda values
    "my_lasso": H2OGeneralizedLinearEstimator(family="binomial",
                                              alpha=1,
                                              nlambdas=100,
                                              fold_assignment="Modulo",
                                              **cv_opts),
    # Random forest with 500 trees
    "my_rf": H2ORandomForestEstimator(ntrees=500,
                                      fold_assignment="Modulo",
                                      **cv_opts),
    # Stacking learner: unpenalized logistic regression with coefficients
    # constrained to be non-negative
    "my_ensemble": H2OGeneralizedLinearEstimator(family="binomial",
                                                 lambda_=0,
                                                 standardize=False,
                                                 non_negative=True,
                                                 **cv_opts),
}

# Define a function that extracts what we need from the models
def GetStuff(model, vimp_frame, perf_frame, roc_frame, preds_frame, tf = None):
    """Append the diagnostics of one trained h2o model to the result frames.

    Besides its arguments, this reads the module-level globals ``y`` (outcome
    name), ``m`` (model specification label), ``alg`` (learner name) and
    ``dstats`` (pandas frame with the outcome and ``jid`` columns of the
    training data), which the main loop sets before each call.

    Parameters
    ----------
    model : trained h2o estimator that kept its cross-validation predictions.
    vimp_frame, perf_frame, roc_frame, preds_frame : pandas DataFrames that
        accumulate variable importances, cross-validated performance, ROC
        points and cross-validated predictions, respectively.
    tf : H2OFrame the model was trained on; must contain an ``OPEN_ID`` column.

    Returns
    -------
    list
        ``[vimp_frame, perf_frame, roc_frame, preds_frame, model_preds]``: the
        four updated accumulators, plus ``model_preds`` holding only this
        model's cross-validated predictions (columns OPEN_ID, prediction,
        outvar, model, algorithm, pred_type).

    Raises
    ------
    ValueError
        If ``tf`` is not given (it is needed to attach OPEN_IDs to predictions).
    """
    if tf is None:
        raise ValueError("GetStuff requires the training frame `tf` (with an OPEN_ID column)")

    # Variable importance
    vimp = pd.DataFrame(model.varimp(), columns=['var', 'rel_imp', 'scaled_imp', 'percentage'])
    vimp["outvar"] = y
    vimp["model"] = m
    vimp["algorithm"] = alg
    vimp_frame = pd.concat([vimp_frame, vimp], axis=0)

    # Cross-validated performance statistics, plus the number of observations
    # and of distinct judges among outcome == 1 and outcome == 0 cases
    perf_row = [y, round(model.mse(xval=True), 6), round(model.auc(xval=True), 6),
                m, alg, len(dstats),
                len(dstats.loc[dstats[y] == 1, "jid"].value_counts()),
                len(dstats.loc[dstats[y] == 0, "jid"].value_counts())]
    perf = pd.DataFrame([perf_row], columns=["outvar", "mse", "auc", "model", "algorithm", "obs", "tr_judges", "co_judges"])
    perf_frame = pd.concat([perf_frame, perf], axis=0)

    # Cross-validated ROC curve; the metrics object is fetched once and reused
    xval_perf = model.model_performance(xval=True)
    roc = pd.DataFrame({"fpr": xval_perf.fprs,
                        "tpr": xval_perf.tprs,
                        "outvar": y, "model": m, "algorithm": alg})
    roc_frame = pd.concat([roc_frame, roc], axis=0)

    # Cross-validation holdout predictions (probability of class 1), aligned
    # by row position with the OPEN_IDs of the training frame
    with h2o.utils.threading.local_context(polars_enabled=True, datatable_enabled=True):
        model_preds = pd.concat([tf["OPEN_ID"].as_data_frame(),
                                 model.cross_validation_holdout_predictions().as_data_frame()["p1"]], axis=1)

    model_preds = model_preds.rename(columns={"p1": "prediction"})
    model_preds["outvar"] = y
    model_preds["model"] = m
    model_preds["algorithm"] = alg
    model_preds["pred_type"] = "cv"
    preds_frame = pd.concat([preds_frame, model_preds], axis=0)

    return [vimp_frame, perf_frame, roc_frame, preds_frame, model_preds]

# Consolidate categorical variables that have low counts to improve model performance
def Preprocess(datafr, min_count=20, cat_vars=None):
    """Collapse infrequent levels of categorical predictors within each division.

    A row's "division" is the leading lowercase-letters-then-digits prefix of
    its ``block`` value. For every variable in ``cat_vars`` other than
    ``block``, any level that occurs at most ``min_count`` times within a
    division is replaced by "OTHER_LOW_FREQUENCY" in that division's rows.

    Parameters
    ----------
    datafr : DataFrame (or anything ``pd.DataFrame`` accepts) with a string
        ``block`` column and the columns listed in ``cat_vars``. Not modified.
    min_count : int, default 20
        Levels whose within-division count is <= this value are collapsed.
    cat_vars : list of str, optional
        Categorical variables to process; defaults to the module-level ``pvars1``.

    Returns
    -------
    DataFrame with the same columns as the input. Row order and index are
    generally NOT preserved (see the note in the loop).
    """
    if cat_vars is None:
        cat_vars = pvars1
    # Work on a copy so the caller's frame is never modified
    datafr = pd.DataFrame(datafr).copy()
    datafr["division"] = datafr["block"].str.extract(r"^([a-z]+[0-9]+)", expand=False)
    for v in [x for x in cat_vars if x != "block"]:
        counts = datafr.groupby("division")[v].value_counts().reset_index(name="n")
        rare = counts.loc[counts["n"] <= min_count, :]
        # Left merge keeps every row; "n" is non-null only for rare (division, level) pairs.
        # NOTE: sort_values("n") reorders the rows (rare ones first). It is kept
        # because the "Modulo" fold assignment used by some learners assigns folds
        # by row position, so changing the order would change the CV folds/results.
        datafr = datafr.merge(rare, how="left", on=["division", v]).sort_values("n")
        datafr.loc[datafr["n"].notna(), v] = "OTHER_LOW_FREQUENCY"
        datafr = datafr.drop("n", axis=1)
    return datafr.drop("division", axis=1)

################################################################################
## Run the models
################################################################################

# Define the response variable for the randomization check:
#   i.e., whether a judge is a nontraditional judge
y = "nontraditional"
for subset in subsets:
    # Output files for this subset are named <table>_<y>_<subset>.csv
    fn = y + "_" + subset + ".csv"

    # Launch a fresh h2o cluster for this subset (always shut down in `finally`)
    h2o.init(max_mem_size=str(mmem) + "G")
    h2o.no_progress()
    try:
        print("\n\n" + "=" * 80 + "\nNow running models on [" + subset + "] subset of cases\n" + "=" * 80 + "\n\n")

        # Result accumulators filled by GetStuff
        roc = pd.DataFrame()
        perf = pd.DataFrame()
        preds = pd.DataFrame()
        vimps = pd.DataFrame()

        # Build the h2o training frame once per subset: it does not depend on the
        # model specification, so both specifications reuse the same frame.
        tf = h2o.H2OFrame(Preprocess(df.loc[subsets[subset], :]))
        for v in [y] + pvars1:
            tf[v] = tf[v].asfactor()

        # Local pandas copy of outcome, judge id and block; GetStuff reads it as a global
        with h2o.utils.threading.local_context(polars_enabled=True, datatable_enabled=True):
            dstats = tf[:, [y, 'jid', 'block']].as_data_frame()

        for m in ["benchmark", "saturated"]:
            print("\n" + "     >> Running " + m + " model ")

            # Benchmark: block only. Saturated: all pretreatment predictors, minus
            # the plaintiff identity variables on the SCALES subsets.
            pvars = ['block']
            if m != "benchmark":
                pvars = pvars1 + pvars2
                if "scales" in subset:
                    pvars = [x for x in pvars if x not in ['pla_female', 'pla_nonwhite_w']]

            # Frame that collects each base learner's cross-validated predictions
            with h2o.utils.threading.local_context(polars_enabled=True, datatable_enabled=True):
                eft = pd.DataFrame(tf[:, ["OPEN_ID", y]].as_data_frame())

            for alg in base_learners:
                if "my_ens" in alg:
                    continue  # the ensemble is trained below on the base learners' predictions
                h2o.show_progress()
                base_learners[alg].train(x=pvars, y=y, training_frame=tf)
                h2o.no_progress()
                vimps, perf, roc, preds, eft1 = GetStuff(base_learners[alg], vimps, perf, roc, preds, tf)
                eft = eft.merge(eft1.loc[:, ["OPEN_ID", "prediction"]].rename(columns={"prediction": alg}), on="OPEN_ID")
                del eft1

            # Estimate the ensemble: a GLM on the base learners' CV predictions
            eft = h2o.H2OFrame(eft)
            eft[y] = eft[y].asfactor()
            h2o.show_progress()
            base_learners["my_ensemble"].train(x=[x for x in eft.columns if 'my_' in x], y=y, training_frame=eft)
            h2o.no_progress()
            vimps, perf, roc, preds, eft1 = GetStuff(base_learners["my_ensemble"], vimps, perf, roc, preds, eft)
            del eft1

        del dstats

        # Write one CSV per results table (explicit mapping instead of eval on names)
        outputs = {"roc": roc, "perf": perf, "preds": preds, "vimps": vimps}
        for f, frame in outputs.items():
            frame.to_csv(rdir + "/" + f + "_" + fn, mode="w", index=False, header=True)
    finally:
        # Free cluster memory and stop the cluster even if a model fails
        h2o.remove_all()
        h2o.cluster().shutdown()
    del roc, perf, preds, vimps, outputs
    time.sleep(5)